{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import jieba\n",
    "import os\n",
    "from torch.nn import init\n",
    "from torchtext import data\n",
    "from torchtext.vocab import Vectors\n",
    "import time\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# data processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 分词\n",
    "def tokenizer(text): \n",
    "    return [word for word in jieba.lcut(text) if word not in stop_words]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 去停用词\n",
    "def get_stop_words():\n",
    "    file_object = open('data/stopwords.txt',encoding='utf-8')\n",
    "    stop_words = []\n",
    "    for line in file_object.readlines():\n",
    "        line = line[:-1]\n",
    "        line = line.strip()\n",
    "        stop_words.append(line)\n",
    "    return stop_words\n",
    "\n",
    "stop_words = get_stop_words()  # 加载停用词表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = data.Field(sequential=True,\n",
    "                  lower=True,\n",
    "                  tokenize=tokenizer,\n",
    "                  stop_words=stop_words)\n",
    "label = data.Field(sequential=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\ADMINI~1\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 1.677 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "train, val = data.TabularDataset.splits(\n",
    "    path='data/',\n",
    "    skip_header=True,\n",
    "    train='train.tsv',\n",
    "    validation='validation.tsv',\n",
    "    format='tsv',\n",
    "    fields=[('index', None), ('label', label), ('text', text)],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['油耗', '显示', '13', '升', '多一点', '希望', '慢慢', '下降', '倒车', '雷达', '真', '可恨']\n",
      "dict_keys(['label', 'text'])\n"
     ]
    }
   ],
   "source": [
    "print(train[2].text)\n",
    "print(train[5].__dict__.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.6281084  -0.99839437 -1.352674   -0.5229339  -0.66837466  0.12023606\n",
      "  0.5181639  -0.441484    1.3996645   1.6288378   1.0222958  -0.48031467\n",
      "  1.0719501  -0.19752608  0.5327201   0.94984937  1.2358193  -0.11119507\n",
      "  0.41765276 -1.8168654  -0.79406536 -0.2047712   1.799883   -1.0999346\n",
      "  0.13697413 -0.8517842   1.3984505  -0.79764915 -0.07896331  1.2744272\n",
      " -0.83588195  0.08194309  1.3844757   0.48616973  2.6387873   0.92588806\n",
      " -1.4051545   1.3981075  -1.3735238  -0.66897595 -0.68339837  2.117681\n",
      " -0.38238457  0.6941145  -1.1269426  -0.56344277  1.2109338   0.66243315\n",
      "  1.0955988  -0.09703278 -0.66029483 -0.50386393 -2.1538718  -0.76798207\n",
      " -0.18256703 -0.72396874 -0.8715642   1.7946631   0.04120466 -1.8321133\n",
      "  0.5928513   0.0921784   0.49934524  0.18206167 -1.4094074   0.06243325\n",
      " -1.4292387   1.1573641  -0.349081   -0.12080618  0.39978337  1.5544038\n",
      "  0.91382134 -1.4675195   0.5610765  -0.37099078  0.594345    0.38196605\n",
      " -1.6270672   0.6772313   1.5092831   0.30081615  1.257672    0.29563576\n",
      "  1.0354363  -0.9241496  -0.30267137 -2.2157154  -1.4336832   0.21979204\n",
      " -1.7822541  -0.897993    0.74574155 -0.8866193   1.1431063  -0.46215868\n",
      "  2.0792334   0.54710203  0.44141322 -0.00408479]\n"
     ]
    }
   ],
   "source": [
    "import gensim\n",
    "model = gensim.models.KeyedVectors.load_word2vec_format('data/myvector.vector', binary=False)\n",
    "print(model['空调'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "cache = 'data/.vector_cache'\n",
    "if not os.path.exists(cache):\n",
    "    os.mkdir(cache)\n",
    "vectors = Vectors(name='data/myvector.vector', cache=cache)\n",
    "# 指定Vector缺失值的初始化方式，没有命中的token的初始化方式\n",
    "#vectors.unk_init = nn.init.xavier_uniform_\n",
    "\n",
    "text.build_vocab(train, val, vectors=vectors)#加入测试集的vertor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#text.build_vocab(train, val, vectors=Vectors(name='data/myvector.vector'))#加入测试集的vertor\n",
    "label.build_vocab(train, val)\n",
    "\n",
    "embedding_dim = text.vocab.vectors.size()[-1]\n",
    "vectors = text.vocab.vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('空间', 13734),\n",
       " ('外观', 10641),\n",
       " ('满意', 9777),\n",
       " ('车', 8260),\n",
       " ('动力', 7856),\n",
       " ('油耗', 7540),\n",
       " ('高', 6119),\n",
       " ('内饰', 6068),\n",
       " ('感觉', 5383),\n",
       " ('配置', 4596)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34841, 100])\n"
     ]
    }
   ],
   "source": [
    "text.vocab.freqs.most_common(10)\n",
    "print(text.vocab.vectors.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size=128\n",
    "train_iter, val_iter = data.Iterator.splits(\n",
    "            (train, val),\n",
    "            sort_key=lambda x: len(x.text),\n",
    "            batch_sizes=(batch_size, len(val)), # 训练集设置batch_size,验证集整个集合用于测试\n",
    "    )\n",
    "\n",
    "vocab_size = len(text.vocab)\n",
    "label_num = len(label.vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34, 128])\n",
      "tensor([[  20,   60,    6,  ...,    3, 1007,  753],\n",
      "        [   8,  127, 1365,  ...,   46,   12,  139],\n",
      "        [   3,  171,  242,  ...,  250,    3,    2],\n",
      "        ...,\n",
      "        [   1,    1,    1,  ...,    1,    1,    1],\n",
      "        [   1,    1,    1,  ...,    1,    1,    1],\n",
      "        [   1,    1,    1,  ...,    1,    1,    1]])\n"
     ]
    }
   ],
   "source": [
    "batch = next(iter(train_iter))\n",
    "data = batch.text\n",
    "print(batch.text.shape)\n",
    "print(batch.text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GlobalMaxPool1d(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GlobalMaxPool1d, self).__init__()\n",
    "    def forward(self, x):\n",
    "         # x shape: (batch_size, channel, seq_len)\n",
    "        return F.max_pool1d(x, kernel_size=x.shape[2]) # shape: (batch_size, channel, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextCNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, kernel_sizes, num_channels):\n",
    "        super(TextCNN, self).__init__()\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)  # embedding之后的shape: torch.Size([200, 8, 300])\n",
    "        self.word_embeddings = self.word_embeddings.from_pretrained(vectors, freeze=False)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.decoder = nn.Linear(sum(num_channels), 2)\n",
    "        # 时序最大池化层没有权重，所以可以共用一个实例\n",
    "        self.pool = GlobalMaxPool1d()\n",
    "        self.convs = nn.ModuleList()  # 创建多个一维卷积层\n",
    "        for c, k in zip(num_channels, kernel_sizes):\n",
    "            self.convs.append(nn.Conv1d(in_channels = embedding_dim, \n",
    "                                        out_channels = c, \n",
    "                                        kernel_size = k))\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        embeds = embeds.permute(0, 2, 1)\n",
    "        # 对于每个一维卷积层，在时序最大池化后会得到一个形状为(批量大小, 通道大小, 1)的\n",
    "        # Tensor。使用flatten函数去掉最后一维，然后在通道维上连结\n",
    "        encoding = torch.cat([self.pool(F.relu(conv(embeds))).squeeze(-1) for conv in self.convs], dim=1)\n",
    "        # 应用丢弃法后使用全连接层得到输出\n",
    "        outputs = self.decoder(self.dropout(encoding))\n",
    "        return outputs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TextCNN(\n",
      "  (word_embeddings): Embedding(34841, 100)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      "  (decoder): Linear(in_features=300, out_features=2, bias=True)\n",
      "  (pool): GlobalMaxPool1d()\n",
      "  (convs): ModuleList(\n",
      "    (0): Conv1d(100, 100, kernel_size=(3,), stride=(1,))\n",
      "    (1): Conv1d(100, 100, kernel_size=(4,), stride=(1,))\n",
      "    (2): Conv1d(100, 100, kernel_size=(5,), stride=(1,))\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "embedding_dim, kernel_sizes, num_channels = 100, [3, 4, 5], [100, 100, 100]\n",
    "net = TextCNN(vocab_size, embedding_dim, kernel_sizes, num_channels)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iter, net):\n",
    "    acc_sum, n = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, batch in enumerate(data_iter):\n",
    "            X, y = batch.text, batch.label\n",
    "            X = X.permute(1, 0)\n",
    "            y.data.sub_(1)  #X转置 y为啥要减1\n",
    "            if isinstance(net, torch.nn.Module):\n",
    "                net.eval() # 评估模式, 这会关闭dropout\n",
    "                acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()\n",
    "                net.train() # 改回训练模式\n",
    "            else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPU\n",
    "                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数\n",
    "                    # 将is_training设置成False\n",
    "                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() \n",
    "                else:\n",
    "                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() \n",
    "            n += y.shape[0]\n",
    "    return acc_sum / n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(train_iter, test_iter, net, loss, optimizer, num_epochs):\n",
    "    batch_count = 0\n",
    "    for epoch in range(num_epochs):\n",
    "        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()\n",
    "        for batch_idx, batch in enumerate(data_iter):\n",
    "            X, y = batch.text, batch.label\n",
    "            X = X.permute(1, 0)\n",
    "            y.data.sub_(1)  #X转置 y为啥要减1\n",
    "            y_hat = net(X)\n",
    "            l = loss(y_hat, y)\n",
    "            optimizer.zero_grad()\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            train_l_sum += l.item()\n",
    "            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()\n",
    "            n += y.shape[0]\n",
    "            batch_count += 1\n",
    "        test_acc = evaluate_accuracy(test_iter, net)\n",
    "        print(\n",
    "            'epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\n",
    "            % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,\n",
    "               test_acc, time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'data_iter' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-18-12384213677e>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnet\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mlr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_iter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mval_iter\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnet\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-17-281462bd4591>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(train_iter, test_iter, net, loss, optimizer, num_epochs)\u001b[0m\n\u001b[0;32m      3\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m         \u001b[0mtrain_l_sum\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_acc_sum\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstart\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0.0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m         \u001b[1;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_iter\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      6\u001b[0m             \u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtext\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlabel\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m             \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'data_iter' is not defined"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "train(train_iter, val_iter, net, loss, optimizer, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "oldHeight": 263.875222,
   "position": {
    "height": "40px",
    "left": "672.097px",
    "right": "20px",
    "top": "40px",
    "width": "350px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "varInspector_section_display": "none",
   "window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
